{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import collections\n",
    "\n",
    "start_token = 'G'\n",
    "end_token = 'E'\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 数据预处理部分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def process_poems(file_name):\n",
    "    poems = []\n",
    "    with open(file_name, \"r\", encoding='utf-8', ) as f:\n",
    "        for line in f.readlines():\n",
    "            try:\n",
    "                title, content = line.strip().split(':')\n",
    "                content = content.replace(' ', '')\n",
    "                if '_' in content or '(' in content or '（' in content or '《' in content or '[' in content or \\\n",
    "                                start_token in content or end_token in content:\n",
    "                    continue\n",
    "                if len(content) < 5 or len(content) > 80:\n",
    "                    continue\n",
    "                content = start_token + content + end_token\n",
    "                poems.append(content)\n",
    "            except ValueError as e:\n",
    "                pass\n",
    "    # 按诗的字数排序\n",
    "    poems = sorted(poems, key=lambda line: len(line))\n",
    "    # 统计每个字出现次数\n",
    "    all_words = []\n",
    "    for poem in poems:\n",
    "        all_words += [word for word in poem]  \n",
    "    counter = collections.Counter(all_words)  # 统计词和词频。\n",
    "    count_pairs = sorted(counter.items(), key=lambda x: -x[1])  # 排序\n",
    "    words, _ = zip(*count_pairs)\n",
    "    words = words[:len(words)] + (' ',)\n",
    "    word_int_map = dict(zip(words, range(len(words))))\n",
    "    poems_vector = [list(map(word_int_map.get, poem)) for poem in poems]\n",
    "    return poems_vector, word_int_map, words"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### rnn_lstm model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2, batch_size=64,\n",
    "              learning_rate=0.01):\n",
    "    end_points = {}\n",
    "    # 构建RNN基本单元RNNcell\n",
    "    if model == 'rnn':\n",
    "        cell_fun = tf.contrib.rnn.BasicRNNCell\n",
    "    elif model == 'gru':\n",
    "        cell_fun = tf.contrib.rnn.GRUCell\n",
    "    else:\n",
    "        cell_fun = tf.contrib.rnn.BasicLSTMCell\n",
    "    #？？？？？？？？？？？？？？？？？？？？？？\n",
    "    # 每层128个小单元，一共有两层，输出的Ct 和 Ht 要分开放到两个tuple中\n",
    "    # 在下面补全代码 \n",
    "    #################################################\n",
    "    cell = cell_fun(  )\n",
    "    cell = tf.contrib.rnn.MultiRNNCell(   )\n",
    "    #################################################\n",
    "    # 如果是训练模式，output_data不为None，则初始状态shape为[batch_size * rnn_size]\n",
    "    # 如果是生成模式，output_data为None，则初始状态shape为[1 * rnn_size]\n",
    "    if output_data is not None:\n",
    "        initial_state = cell.zero_state(batch_size, tf.float32)\n",
    "    else:\n",
    "        initial_state = cell.zero_state(1, tf.float32)\n",
    "\n",
    "    # 构建隐层\n",
    "    with tf.device(\"/cpu:0\"):\n",
    "        embedding = tf.Variable(tf.random_uniform([vocab_size + 1, rnn_size], -1.0, 1.0),name = 'embedding')\n",
    "        inputs = tf.nn.embedding_lookup(embedding, input_data)\n",
    "    #？？？？？？？？？？？？？？？？？？？？？？？？？？\n",
    "    ####################################################    \n",
    "    outputs, last_state = tf.nn.dynamic_rnn(    )# 填写里面的内容\n",
    "    ######################################################\n",
    "    output = tf.reshape(outputs, [-1, rnn_size])\n",
    "    \n",
    "    weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))\n",
    "    bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))\n",
    "    logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias) # 一层全连接\n",
    "\n",
    "\n",
    "    if output_data is not None: # 训练模式\n",
    "        labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)\n",
    "        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)\n",
    "        total_loss = tf.reduce_mean(loss)\n",
    "        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)  # 优化器用的 adam\n",
    "        end_points['initial_state'] = initial_state\n",
    "        end_points['output'] = output\n",
    "        end_points['train_op'] = train_op\n",
    "        end_points['total_loss'] = total_loss\n",
    "        end_points['loss'] = loss\n",
    "        end_points['last_state'] = last_state\n",
    "    else: # 生成模式\n",
    "        prediction = tf.nn.softmax(logits)\n",
    "        end_points['initial_state'] = initial_state\n",
    "        end_points['last_state'] = last_state\n",
    "        end_points['prediction'] = prediction\n",
    "    return end_points"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型部分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def run_training():\n",
    "    # 处理数据集\n",
    "    poems_vector, word_to_int, vocabularies = process_poems('./poems.txt')\n",
    "    # 生成batch\n",
    "    batches_inputs, batches_outputs = generate_batch(64, poems_vector, word_to_int)\n",
    "\n",
    "    input_data = tf.placeholder(tf.int32, [batch_size, None])\n",
    "    output_targets = tf.placeholder(tf.int32, [batch_size, None])\n",
    "    # 构建模型\n",
    "    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(\n",
    "        vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=0.01)\n",
    "\n",
    "    saver = tf.train.Saver()\n",
    "    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(init_op)\n",
    "        for epoch in range(50):\n",
    "            n = 0\n",
    "            n_chunk = len(poems_vector) // batch_size\n",
    "            for batch in range(n_chunk):\n",
    "                loss, _, _ = sess.run([\n",
    "                    end_points['total_loss'],\n",
    "                    end_points['last_state'],\n",
    "                    end_points['train_op']\n",
    "                ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})\n",
    "                n += 1\n",
    "                print('[INFO] Epoch: %d , batch: %d , training loss: %.6f' % (epoch, batch, loss))\n",
    "        saver.save(sess, './poem_generator')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 生成 诗歌部分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def gen_poem(begin_word):\n",
    "    batch_size = 1\n",
    "    poems_vector, word_int_map, vocabularies = process_poems('./poems.txt')\n",
    "\n",
    "    input_data = tf.placeholder(tf.int32, [batch_size, None])\n",
    "\n",
    "    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(\n",
    "        vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=0.01)\n",
    "    # 如果指定开始的字\n",
    "    if begin_word:\n",
    "        word = begin_word\n",
    "    else:\n",
    "        word = to_word(predict, vocabularies)\n",
    "        \n",
    "    saver = tf.train.Saver(tf.global_variables())\n",
    "    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())\n",
    "\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(init_op)\n",
    "        saver.restore(sess, './poem_generator')# 恢复之前训练好的模型 \n",
    "        poem = ''\n",
    "        #???????????????????????????????????????\n",
    "        # 下面部分代码主要功能是根据指定的开始字符来生成诗歌\n",
    "        #########################################\n",
    "        \n",
    "        \n",
    "        #########################################\n",
    "        return poem\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 其他的一些处理函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generate_batch(batch_size, poems_vec, word_to_int):\n",
    "    # 每次取64首诗进行训练\n",
    "    n_chunk = len(poems_vec) // batch_size\n",
    "    x_batches = []\n",
    "    y_batches = []\n",
    "    for i in range(n_chunk):\n",
    "        start_index = i * batch_size\n",
    "        end_index = start_index + batch_size\n",
    "\n",
    "        batches = poems_vec[start_index:end_index]\n",
    "        # 找到这个batch的所有poem中最长的poem的长度\n",
    "        length = max(map(len, batches))\n",
    "        # 填充一个这么大小的空batch，空的地方放空格对应的index标号\n",
    "        x_data = np.full((batch_size, length), word_to_int[' '], np.int32)\n",
    "        for row in range(batch_size):\n",
    "            x_data[row, :len(batches[row])] = batches[row]\n",
    "        y_data = np.copy(x_data)\n",
    "        y_data[:, :-1] = x_data[:, 1:]\n",
    "        \"\"\"\n",
    "        x_data             y_data\n",
    "        [6,2,4,6,9]       [2,4,6,9,9]\n",
    "        [1,4,2,8,5]       [4,2,8,5,5]\n",
    "        \"\"\"\n",
    "        x_batches.append(x_data)\n",
    "        y_batches.append(y_data)\n",
    "    return x_batches, y_batches\n",
    "\n",
    "def to_word(predict, vocabs):# 预测的结果转化成汉字\n",
    "    sample = np.argmax(predict)\n",
    "    if sample > len(vocabs):\n",
    "        sample = len(vocabs) - 1\n",
    "    return vocabs[sample]\n",
    "def pretty_print_poem(poem):#  令打印的结果更工整\n",
    "    poem_sentences = poem.split('。')\n",
    "    for s in poem_sentences:\n",
    "        if s != '' and len(s) > 10:\n",
    "            print(s + '。')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 主函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print('[INFO] train tang poem...')\n",
    "run_training() # 训练模型\n",
    "print('[INFO] write tang poem...')\n",
    "poem2 = gen_poem('月')# 生成诗歌\n",
    "print(\"#\" * 25)\n",
    "pretty_print_poem(poem2)\n",
    "print('#' * 25)\n",
    "#训练模型时间比较长，训练模型完成后每次生成诗歌的时，不需要再次训练 ，可以注销上面的 run_training()。生成部分执行速度很快"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
